{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sensitivity Analysis\n",
    "\n",
    "\n",
    "Some pruning algorthims tune their hyperparameters based on the results of pruning sensitivity analysis.  Distiller support L1-norm element-wise pruning sensitivity analysis, and filter-wise pruning sensitivity analysis based on the mean L1-norm ranking of filters.\n",
    "\n",
    "## Table of Contents\n",
    "\n",
    "1. [Load a pruning sensitivity analysis file](#Load-a-pruning-sensitivity-analysis-file)\n",
    "2. [Examine parameters sensitivities](#Examine-parameters-sensitivities)<br>\n",
    "    2.1. [Plot layer sensitivities at a selected sparsity level](#Plot-layer-sensitivities-at-a-selected-sparsity-level)<br>\n",
    "    2.2. [Compare layer sensitivities](#Compare-layer-sensitivities)\n",
    "3. [Filter pruning sensitivity analysis](#Filter-pruning-sensitivity-analysis)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load a pruning sensitivity analysis file\n",
    "\n",
    "You prepare a sensitivity analysis file by invoking ```distiller.perform_sensitivity_analysis()```.  Checkout the documentation of ```distiller.perform_sensitivity_analysis()``` for more information.<br>\n",
    "Alternatively, you can use the sample ```compress_classifier.py``` application to perform sensitivity analysis on one of the supported models.  In the example below, we invoke sensitivity analysis on a pretrained Resnet18 from torchvision, using the ImageNet test dataset for evaluation. \n",
    "\n",
    "```\n",
    "$ python3 compress_classifier.py -a resnet18 ../../../data.imagenet -j 12 --pretrained --sense=element\n",
    "```\n",
    "\n",
    "The outputs of performing pruning sensitivity analysis on several different networks is available at ```../examples/sensitivity-analysis``` "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import ipywidgets as widgets\n",
    "from ipywidgets import interactive, interact, Layout\n",
    "\n",
    "df = df = pd.read_csv('../examples/sensitivity-analysis/mobilenet-imagenet/sensitivity.csv')\n",
    "#df = pd.read_csv('../examples/sensitivity-analysis/resnet18-imagenet/sensitivity.csv')\n",
    "#df = pd.read_csv('../examples/sensitivity-analysis/resnet56-cifar/sensitivity_filter_wise.csv')\n",
    "#df = pd.read_csv('../examples/sensitivity-analysis/resnet20-cifar/sensitivity_filter_wise.csv')\n",
    "\n",
    "df['sparsity'] = round(df['sparsity'], 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code below converts the sensitivities dataframe to a sensitivities dictionary. <br> \n",
    "Using this dictionary makes it easier for us when we want to plot sensitivities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "def get_param_names(df):\n",
    "    return list(set(df['parameter']))\n",
    "\n",
    "def get_sensitivity_levels(df):\n",
    "    return list(set(df['sparsity']))\n",
    "\n",
    "def df2sensitivities(df):\n",
    "    param_names = get_param_names(df)\n",
    "    sparsities = get_sensitivity_levels(df)\n",
    "\n",
    "    sensitivities = {}\n",
    "    for param_name in param_names:\n",
    "        sensitivities[param_name] = OrderedDict()\n",
    "        param_stats = df[(df.parameter == param_name)]\n",
    "        \n",
    "        for row in range(len(param_stats.index)):\n",
    "            s = param_stats.iloc[[row]].sparsity\n",
    "            top1 = param_stats.iloc[[row]].top1\n",
    "            top5 = param_stats.iloc[[row]].top5\n",
    "            sensitivities[param_name][float(s)] = (float(top1), float(top5))\n",
    "    return sensitivities "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Examine parameters sensitivities\n",
    "\n",
    "After loading the sensitivity analysis CSV file into a Pandas dataframe, we can examine it.\n",
    "\n",
    "### Plot layer sensitivities at a selected sparsity level\n",
    "Use the dropdown to choose the sparsity level, and select whether you choose to view the top1 accuracies or top5.<br>\n",
    "Under the plot we display the numerical values of the accuracies, in case you want to have a closer look at the details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def view2(level, acc):\n",
    "    filtered = df[df.sparsity == level]\n",
    "    s = filtered.style.apply(highlight_min_max)\n",
    "    \n",
    "    param_names = filtered['parameter']\n",
    "    \n",
    "    # Plot the sensitivities\n",
    "    x = range(filtered[acc].shape[0])\n",
    "    y = filtered[acc].values.tolist()\n",
    "    fig = plt.figure(figsize=(20,10))\n",
    "    plt.plot(x, y, label=param_names, marker=\"o\", markersize=10, markerfacecolor=\"C1\")\n",
    "    plt.ylabel(str(acc))\n",
    "    plt.xlabel('parameter')\n",
    "    plt.xticks(rotation='vertical')\n",
    "    plt.xticks(x, param_names)\n",
    "    plt.title('Pruning Sensitivity per layer %d' % level)    \n",
    "    #return s\n",
    "\n",
    "def highlight_min_max(s):\n",
    "    \"\"\"Highlight the max and min values in the series\"\"\"\n",
    "    if s.name not in ['top1', 'top5']:\n",
    "        return ['' for v in s] \n",
    "    \n",
    "    is_max = s == s.max()\n",
    "    maxes = ['background-color: green' if v else '' for v in is_max]\n",
    "    is_min = s == s.min()\n",
    "    mins = ['background-color: red' if v else '' for v in is_min]    \n",
    "    return [h1 if len(h1)>len(h2) else h2 for (h1,h2) in zip(maxes, mins)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "sparsities = np.sort(get_sensitivity_levels(df))\n",
    "acc_radio = widgets.RadioButtons(options=['top1', 'top5'], value='top1', description='Accuracy:')\n",
    "levels_dropdown = widgets.Dropdown(description='Sparsity:', options=sparsities)\n",
    "interact(view2, level=levels_dropdown, acc=acc_radio);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sometimes we want to look at the sensitivies of a specific weights tensor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_sparsity(param_name):\n",
    "    display(df[df['parameter']==param_name])\n",
    "\n",
    "param_names = sorted(df['parameter'].unique().tolist())\n",
    "param_dropdown = widgets.Dropdown(description='Parameter:', options=param_names)\n",
    "interact(view_sparsity, param_name=param_dropdown);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare layer sensitivities\n",
    "\n",
    "Plot the pruning sensitivities of selected layers.\n",
    "<br>Select multiple parameters using SHIFT and CTRL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# assign a different color to each parameter (otherwise, colors change on us as we make different selections)\n",
    "param_names = sorted(df['parameter'].unique().tolist())\n",
    "color_idx = np.linspace(0, 1, len(param_names))\n",
    "colors = {}  \n",
    "for i, pname in zip(color_idx, param_names):\n",
    "    colors[pname] = color= plt.get_cmap('tab20')(i)\n",
    "plt.rcParams.update({'font.size': 18})\n",
    "\n",
    "def view(weights='', acc=0):\n",
    "    sensitivities= None\n",
    "    if weights[0]=='All':\n",
    "        sensitivities = df2sensitivities(df)\n",
    "    else:\n",
    "        mask = False\n",
    "        mask = [(df.parameter == pname) for pname in weights]\n",
    "        mask = np.logical_or.reduce(mask)\n",
    "        sensitivities = df2sensitivities(df[mask])\n",
    "\n",
    "\n",
    "    # Plot the sensitivities\n",
    "    fig = plt.figure(figsize=(20,10))\n",
    "    for param_name, sensitivity in sensitivities.items():\n",
    "        sense = [values[acc] for sparsity, values in sensitivity.items()]\n",
    "        sparsities = [sparsity for sparsity, values in sensitivity.items()]\n",
    "        plt.plot(sparsities, sense, label=param_name, marker=\"o\", markersize=10, color=colors[param_name])\n",
    "\n",
    "    plt.ylabel('top1')\n",
    "    plt.xlabel('sparsity')\n",
    "    plt.title('Pruning Sensitivity')\n",
    "    #plt.legend(loc='lower center', ncol=2, mode=\"expand\", borderaxespad=0.);\n",
    "    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), fancybox=True, shadow=True, ncol=3)\n",
    "\n",
    "items = ['All']+param_names\n",
    "w = widgets.SelectMultiple(options=items, value=[items[1]], layout=Layout(width='50%'), description='Weights:')\n",
    "acc_widget = widgets.RadioButtons(options={'top1': 0, 'top5': 1}, value=0, description='Accuracy:')\n",
    "interactive(view, acc=acc_widget, weights=w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Filter pruning sensitivity analysis\n",
    "\n",
    "Just as we perform element-wise pruning sensitivity analysis, we can also analyze a model's filter-wise pruning sensitivity.  Although the sparsity levels are reported in percentage steps, the actual pruning level might be somewhat lower, because when we prune filters the minimum granularity of pruning is ```1/numer_of_filters```.\n",
    "\n",
    "\n",
    "We performed a filter-wise pruning sensitivity analysis on ResNet20-Cifar using the following command:\n",
    "```\n",
    "python3 compress_classifier.py -a resnet20_cifar ../../../data.cifar10/ -j 12 --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --sense=filter\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_filter  = pd.read_csv('../examples/sensitivity-analysis/resnet20-cifar/sensitivity_filter_wise.csv')\n",
    "df_element = pd.read_csv('../examples/sensitivity-analysis/resnet20-cifar/sensitivity.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_sparsity(param_name):\n",
    "    display(df_filter[df_filter['parameter']==param_name])\n",
    "    \n",
    "param_names = sorted(df_filter['parameter'].unique().tolist())\n",
    "param_dropdown = widgets.Dropdown(description='Parameter:', options=param_names)\n",
    "interact(view_sparsity, param_name=param_dropdown);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's look at the sparsity vs. the compute:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_fliters(level, acc):\n",
    "    filtered = df_filter[df_filter.sparsity == level]\n",
    "    s = filtered.style.apply(highlight_min_max)    \n",
    "    param_names = filtered['parameter']\n",
    "    \n",
    "    # Plot the sensitivities\n",
    "    x = range(filtered[acc].shape[0])\n",
    "    y = filtered[acc].values.tolist()\n",
    "    fig = plt.figure(figsize=(20,10))\n",
    "    plt.plot(x, y, label=param_names, marker=\"o\", markersize=10, markerfacecolor=\"C1\")\n",
    "    plt.ylabel(str(acc))\n",
    "    plt.xlabel('parameter')\n",
    "    plt.xticks(rotation='vertical')\n",
    "    plt.xticks(x, param_names)\n",
    "    plt.title('Filter pruning sensitivity per layer %d' % level)    \n",
    "    return s\n",
    "\n",
    "df_filter['sparsity'] = round(df_filter['sparsity'], 2)\n",
    "sparsities = np.sort(get_sensitivity_levels(df_filter))\n",
    "acc_radio = widgets.RadioButtons(options=['top1', 'top5'], value='top1', description='Accuracy:')\n",
    "levels_dropdown = widgets.Dropdown(description='Sparsity:', options=sparsities)\n",
    "interact(view_fliters, level=levels_dropdown, acc=acc_radio);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
